from distributed_pcg.utils import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp



def create_data(n,d,alpha,seed):
    """Generate a synthetic binary-classification problem with a geometric spectrum.

    A Gaussian ``n x d`` matrix is drawn, its singular vectors are kept and its
    singular values are replaced by ``alpha**(i/2)`` for ``i = 0, 1, ...``.
    Labels are the 0/1-encoded signs of a heavily noised linear response
    ``A @ x + 100 * noise``.

    Args:
        n: number of rows (samples) of the data matrix.
        d: number of columns (features) of the data matrix.
        alpha: decay base of the spectrum; the i-th singular value is
            ``alpha**(i/2)``.
        seed: value passed to ``torch.manual_seed`` so the draw is reproducible.

    Returns:
        Tuple ``(A, y)`` with ``A`` of shape ``(n, d)`` and ``y`` of shape ``(n,)``.
    """
    torch.manual_seed(seed)
    gaussian = torch.randn(size=(n, d))
    U, _, Vt = torch.linalg.svd(gaussian, full_matrices=False)
    # The reduced SVD has min(n, d) factors, so the spectrum must have exactly
    # that many entries (using n entries breaks the product whenever n > d).
    rank = min(n, d)
    sqrt_spectrum = torch.Tensor([alpha ** i for i in range(rank)]) ** 0.5
    # Scaling the columns of U is the same as multiplying by diag(sqrt_spectrum).
    A = (U * sqrt_spectrum) @ Vt
    x = torch.randn(d)
    y = (torch.sign(A @ x + 100 * torch.randn(n)) + 1) / 2

    return A, y

class obj_func():
    """Ridge-regularised logistic-regression objective with sketched Newton oracles.

    ``eval`` computes

        f(x) = (1/n) * sum_i [ -y_i*log s(a_i.x) - (1-y_i)*log s(-a_i.x) ] + l*||x||^2

    where ``s`` is the sigmoid and ``a_i`` are the rows of ``A``.  ``grad`` and
    ``hessian`` are the exact derivatives of that expression, and
    ``hessian_inv`` returns randomised approximations of the inverse Hessian.

    Attributes are plain copies of the constructor arguments:
        A: data matrix, indexed as (samples, features).
        y: per-sample targets weighting the two log-sigmoid terms.
        n: normaliser of the data term and column count of the sample sketches
           (presumably ``A.shape[0]``; the caller must keep the two consistent).
        l: ridge coefficient; the regulariser contributes ``2*l*I`` to the Hessian.
        q: number of independent sketches averaged by ``hessian_inv``.
    """

    def __init__(self, A, y, n, l, q):
        self.A = A
        self.y = y
        self.n = n
        self.l = l
        self.q = q

    def _curvature(self, x):
        """Return sigmoid(Ax) * sigmoid(-Ax), the per-sample weight of the data Hessian."""
        margin = self.A @ x
        return torch.sigmoid(margin) * torch.sigmoid(-margin)

    def _effective_dim(self, H):
        """Return trace(H (H + 2l I)^-1) for the data Hessian ``H``."""
        d = H.shape[0]
        return torch.trace(H @ torch.linalg.inv(H + 2 * self.l * torch.eye(d)))

    def eval(self, x):
        """Return the regularised objective at ``x`` as a 0-dim tensor."""
        log_sigmoid = torch.nn.LogSigmoid()
        margin = self.A @ x
        losses = -(self.y * log_sigmoid(margin)) - (1 - self.y) * log_sigmoid(-margin)
        return torch.sum(losses) / self.n + self.l * torch.linalg.norm(x) ** 2

    def grad(self, x):
        """Return the gradient of ``eval`` at ``x`` (shape ``(d,)``)."""
        residual = torch.sigmoid(self.A @ x) - self.y
        # A^T r is the column sum of diag(r) A without forming the n x n diagonal.
        return self.A.t() @ residual / self.n + 2 * self.l * x

    def hessian(self, x):
        """Return the full Hessian of ``eval`` at ``x``: data term plus ``2*l*I``."""
        return self.hessian_main(x) + 2 * self.l * torch.eye(self.A.shape[1])

    def hessian_main(self, x):
        """Return the data-term Hessian ``A^T diag(w) A / n`` without the ridge."""
        # Row scaling replaces the dense n x n diag(w) product.
        weighted_rows = self._curvature(x).unsqueeze(1) * self.A
        return self.A.t() @ weighted_rows / self.n

    @staticmethod
    def _mean_resolvent_trace(SHS, lam):
        """Return trace((SHS + lam*I)^-1) / k for a k x k matrix ``SHS``."""
        k = SHS.shape[0]
        return torch.trace(torch.linalg.inv(SHS + lam * torch.eye(k))) / k

    def _debiased_lambda(self, H):
        """Pick a sketch size and a corrected ridge value for the debiased methods.

        The sketch size ``m`` is doubled from 2 until a fresh ``m x d`` Gaussian
        sketch ``S`` satisfies ``tr((S H S^T + (5/12)*2l*I)^-1)/m > 1/(2l)`` (or
        ``m`` reaches ``d``).  With the last sketch drawn, a bisection over
        ``[(5/12)*2l, 2l]`` looks for ``lam`` such that
        ``tr((S H S^T + lam*I)^-1)/k`` is within 0.1 of ``1/(2l)``, using at most
        51 halvings.

        Returns:
            Tuple ``(m, lam)``; ``lam`` is a 0-dim tensor.

        Raises:
            ValueError: if ``d <= 2`` so that no sketch can be drawn.
            AssertionError: if the initial interval does not bracket the target.
        """
        d = H.shape[0]
        target = 1 / (2 * self.l)
        lam_low = 5 * (2 * self.l) / 12

        m = 2
        SHS = None
        while m < d:
            S = torch.randn(size=(m, d)) / m ** 0.5
            SHS = S @ H @ S.T
            if self._mean_resolvent_trace(SHS, lam_low) > target:
                break
            m = 2 * m
        if SHS is None:
            raise ValueError(
                "debiased sketching needs more than 2 features, got d=%d" % d)

        bracket = torch.Tensor([lam_low, 2 * self.l])
        assert self._mean_resolvent_trace(SHS, bracket[0]) >= target, \
            "lower end of the ridge interval does not reach the target"
        assert self._mean_resolvent_trace(SHS, bracket[1]) <= target, \
            "upper end of the ridge interval overshoots the target"
        for _ in range(51):
            mid = bracket.mean()
            value = self._mean_resolvent_trace(SHS, mid)
            # Written as "not >" so a NaN value stops the search, as a plain
            # while-condition on the same comparison would.
            if not torch.abs(value - target) > 0.1:
                break
            if value > target:
                bracket[0] = mid
            else:
                bracket[1] = mid
        return m, bracket.mean()

    def _sample_sketch_inverse(self, x, m, scale=1.0):
        """Average ``q`` inverses built from ``m x n`` Gaussian sketches of the samples.

        Each term is ``(scale * (S B)^T (S B) / n + 2l*I)^-1`` with
        ``B = diag(sqrt(w)) A`` and ``S`` drawn fresh per term.
        """
        # B and the ridge do not depend on the sketch, so build them once.
        half_rows = (self._curvature(x) ** 0.5).unsqueeze(1) * self.A
        ridge = 2 * self.l * torch.eye(self.A.shape[1])
        total = 0
        for _ in range(self.q):
            S = torch.randn(size=(m, self.n)) / m ** 0.5
            sketched = S @ half_rows
            total += torch.linalg.inv(scale * (sketched.t() @ sketched) / self.n + ridge)
        return total / self.q

    def _feature_sketch_inverse(self, H, m, lam):
        """Average ``q`` terms ``S^T (S H S^T + lam*I)^-1 S`` with ``m x d`` Gaussian ``S``."""
        d = H.shape[0]
        ridge = lam * torch.eye(m)
        total = 0
        for _ in range(self.q):
            S = torch.randn(size=(m, d)) / m ** 0.5
            total += S.T @ torch.linalg.inv(S @ H @ S.T + ridge) @ S
        return total / self.q

    def hessian_inv(self, x, method='sample'):
        """Return a randomised approximation of the inverse Hessian at ``x``.

        Args:
            x: point at which the Hessian is approximated.
            method: one of
                'sample'         - sketch the sample dimension; sketch size is
                                   ``4*(floor(effective dim)+1)``.
                'sample_debias'  - as 'sample', but the sketch size and a
                                   rescaling of the sketched data term come from
                                   ``_debiased_lambda``.
                'feature'        - sketch the feature dimension of the data
                                   Hessian; same sketch-size rule as 'sample'.
                'feature_debias' - as 'feature', with the sketch size and the
                                   ridge value taken from ``_debiased_lambda``.

        Returns:
            ``d x d`` tensor, the average over ``q`` independent sketches.

        Raises:
            ValueError: for an unrecognised ``method`` (previously this fell
                through and returned ``None``).
        """
        if method == 'sample':
            H = self.hessian_main(x)
            m = 4 * (int(self._effective_dim(H)) + 1)
            return self._sample_sketch_inverse(x, m)

        if method == 'sample_debias':
            H = self.hessian_main(x)
            m, lam = self._debiased_lambda(H)
            return self._sample_sketch_inverse(x, m, scale=lam / (2 * self.l))

        if method == 'feature':
            H = self.hessian_main(x)
            m = 4 * (int(self._effective_dim(H)) + 1)
            return self._feature_sketch_inverse(H, m, 2 * self.l)

        if method == 'feature_debias':
            H = self.hessian_main(x)
            m, lam = self._debiased_lambda(H)
            return self._feature_sketch_inverse(H, m, lam)

        raise ValueError(
            "unknown method %r; expected 'sample', 'sample_debias', "
            "'feature' or 'feature_debias'" % (method,))

def line_search(func,x,v):
    """Backtracking line search along ``v`` with a sufficient-decrease test.

    Starting from ``t = 100`` the step is halved until
    ``func.eval(x + t*v) <= func.eval(x) + 0.25 * t * (func.grad(x) @ v)``.

    Args:
        func: object exposing ``eval(x)`` and ``grad(x)``.
        x: current point.
        v: search direction.

    Returns:
        The first step size that passes the test.

    Raises:
        Exception: with message 'line search error' once ``t`` drops below 1e-20.
    """
    # Neither the base value nor the directional slope depends on t, so they
    # are evaluated once instead of on every backtracking step.
    base_value = func.eval(x)
    slope = func.grad(x) @ v
    t = 100
    while func.eval(x + t * v) > base_value + 0.25 * t * slope:
        t = t * 0.5
        if t < 1e-20:
            raise Exception('line search error')
    return t

def newton(A,y,n,l,cvx_opt,max_iter,q,method='sample'):
    """Run a sketched Newton iteration from the origin and record the objective.

    Each step uses ``obj_func.hessian_inv(x, method=method)`` applied to the
    gradient as the search direction and ``line_search`` for the step size.

    Args:
        A, y, n, l, q: forwarded unchanged to ``obj_func``.
        cvx_opt: reference optimum; accepted for interface compatibility but
            not used by the iteration.
        max_iter: number of Newton steps to take.
        method: 'sample', 'sample_debias', 'feature' or 'feature_debias'.

    Returns:
        List of ``(iteration, objective value)`` pairs, one per step, where the
        value is recorded before the step is taken.

    Raises:
        ValueError: for an unrecognised ``method`` (previously ``None`` was
            returned silently).
    """
    if method not in ('sample', 'sample_debias', 'feature', 'feature_debias'):
        raise ValueError(
            "unknown method %r; expected 'sample', 'sample_debias', "
            "'feature' or 'feature_debias'" % (method,))

    f = obj_func(A, y, n, l, q)
    x = torch.zeros(A.shape[1])
    record = []
    step = 0
    while step < max_iter:
        record.append((step, f.eval(x)))
        v = -f.hessian_inv(x, method=method) @ f.grad(x)
        t = line_search(f, x, v)
        x += t * v
        step += 1
    return record


def get_x_axis(data_set):
    """Build a common, evenly spaced x-axis covering the furthest-reaching run.

    Every element of ``data_set`` is a sequence of records whose first field is
    the x-coordinate.  The run whose final x-coordinate is largest (the first
    one on ties) fixes both the right end of the axis and the point count:
    two points if that run has exactly two records, otherwise at least 100.
    """
    furthest_x = -1
    run_length = 0
    for run in data_set:
        final_x = run[-1][0]
        if final_x > furthest_x:
            furthest_x, run_length = final_x, len(run)

    n_points = 2 if run_length == 2 else max(run_length, 100)
    return np.linspace(0, furthest_x, n_points)

def interpolate(data, axis, optimal):
    """Resample every run onto ``axis`` and summarise the runs by quantiles.

    Run ``i`` is converted with ``get_numbers(data[i], axis, optimal[i])``.

    Returns:
        Tuple ``(q50, q20, q80)`` of arrays over ``axis``: the median curve and
        the 20% / 80% quantile curves taken across runs.
    """
    n_runs, n_points = len(data), len(axis)
    curves = np.zeros((n_runs, n_points))
    for row, run in enumerate(data):
        curves[row] = get_numbers(run, axis, optimal[row])
    assert curves.shape == (n_runs, n_points)

    median, lower, upper = (
        np.quantile(curves, level, axis=0) for level in (0.5, 0.2, 0.8)
    )
    return (median, lower, upper)

def get_numbers(data, axis, optimal=None):
    """Turn one run into an error curve sampled on ``axis``.

    ``data`` is a sequence of ``(x, value)`` records.  With ``optimal`` given
    the error is ``|value - optimal|``; without it the error is measured
    relative to the run's last value, ``|value - last| / |last|``.  The error
    is then linearly interpolated at the positions in ``axis``.
    """
    table = np.array(torch.Tensor(data))
    positions = table[:, 0]
    values = table[:, 1]

    if optimal is not None:
        optimal = np.array(optimal)
        errors = np.abs(values - optimal)
    else:
        last_value = table[-1, 1]
        errors = np.abs(values - last_value) / np.abs(last_value)

    return np.interp(axis, positions, errors)

def plot_multi_realdata():
    """Run the sketched-Newton comparison on 10 synthetic problems and save the plot.

    For seeds 0..9 a problem is drawn with ``create_data``, its optimum is
    computed with cvxpy (CLARABEL solver), and ``newton`` is run with the
    'feature' and 'feature_debias' methods.  The median optimality gap and the
    20%-80% quantile band over seeds are drawn on a log axis and written to
    'gaussian_logistic.pdf'.
    """
    n=100; d=1000; l=1e-3; q=300; alpha=0.99
    max_iter = 10
    plain_runs = []
    debiased_runs = []
    optimals = []
    for i in range(10):
        print(i)
        A, y = create_data(n,d,alpha,seed=i)

        # cvx check
        x_cvx = cp.Variable(d)
        prob = cp.Problem(cp.Minimize((y@cp.logistic(-A@x_cvx)+(1-y)@cp.logistic(A@x_cvx))/n+l*cp.sum_squares(x_cvx)))
        prob.solve(solver='CLARABEL')
        cvx_opt = prob.value
        optimals.append(cvx_opt)

        # NOTE(review): the runs use the 'feature' sketches while the curves
        # below are labelled 'sample' / 'sample_debias' -- confirm which one
        # is intended before publishing the figure.
        plain_runs.append(newton(A,y,n,l,cvx_opt,max_iter,q,method='feature'))
        debiased_runs.append(newton(A,y,n,l,cvx_opt,max_iter,q,method='feature_debias'))

    plain_axis = get_x_axis(plain_runs)
    debiased_axis = get_x_axis(debiased_runs)
    plain_mid, plain_low, plain_high = interpolate(plain_runs, plain_axis, optimals)
    debiased_mid, debiased_low, debiased_high = interpolate(debiased_runs, debiased_axis, optimals)

    # plt.subplots() opens its own figure; the extra 100x100 inch figure that
    # used to be created before it was never drawn on or closed.
    fig, ax = plt.subplots()
    clrs = sns.color_palette("husl", 10)
    ax.plot(plain_axis, plain_mid, label='sample', c=clrs[7])
    ax.fill_between(plain_axis, plain_low, plain_high, alpha=0.3, facecolor=clrs[7])
    ax.plot(debiased_axis, debiased_mid, label='sample_debias', c=clrs[4])
    ax.fill_between(debiased_axis, debiased_low, debiased_high, alpha=0.3, facecolor=clrs[4])
    ax.set_yscale('log')
    plt.xlim(left=0,right=6)
    plt.ylim(bottom=1e-7)
    plt.xlabel('Newton Steps', fontsize=20)
    plt.ylabel('Log Optimality Gap', fontsize=20)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    plt.tight_layout()
    plt.savefig('gaussian_logistic.pdf')




# Script entry point: run the experiment and write the figure only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    plot_multi_realdata()

